package com.wang.sort3;

import java.util.Comparator;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;
import com.wang.sort3.Demo08_GraphGenerator.*;

/**
 * @Author: along
 * @Create: 2021/4/24
 * @Description:
 */
public class Demo13_Prim {
    // undirected graph only
        public static class EdgeComparator implements Comparator<Edge> {

            @Override
            public int compare(Edge o1, Edge o2) {
                return o1.weight - o2.weight;
            }

        }

        public static Set<Edge> primMST(Graph graph) {
            // 解锁的边进入小根堆
            PriorityQueue<Edge> priorityQueue = new PriorityQueue<>(new EdgeComparator());
            HashSet<Node> set = new HashSet<>();
            Set<Edge> result = new HashSet<>(); // 依次挑选的的边在result里
            for (Node node : graph.nodes.values()) { // 随便挑了一个点
                // node 是开始点
                if (!set.contains(node)) {
                    set.add(node);
                    for (Edge edge : node.edges) { // 由一个点，解锁所有相连的边
                        priorityQueue.add(edge);
                    }
                    while (!priorityQueue.isEmpty()) {
                        Edge edge = priorityQueue.poll(); // 弹出解锁的边中，最小的边
                        Node toNode = edge.to; // 可能的一个新的点
                        if (!set.contains(toNode)) { // 不含有的时候，就是新的点,就加入解锁集和结果集
                            set.add(toNode);
                            result.add(edge);
                            for (Edge nextEdge : toNode.edges) {
                                priorityQueue.add(nextEdge);
                            }
                        }
                    }
                }
                //break;
            }
            return result;
        }

        // 请保证graph是连通图
        // graph[i][j]表示点i到点j的距离，如果是系统最大值代表无路
        // 返回值是最小连通图的路径之和
        public static int prim(int[][] graph) {
            int size = graph.length;
            int[] distances = new int[size];
            boolean[] visit = new boolean[size];
            visit[0] = true;
            for (int i = 0; i < size; i++) {
                distances[i] = graph[0][i];
            }
            int sum = 0;
            for (int i = 1; i < size; i++) {
                int minPath = Integer.MAX_VALUE;
                int minIndex = -1;
                for (int j = 0; j < size; j++) {
                    if (!visit[j] && distances[j] < minPath) {
                        minPath = distances[j];
                        minIndex = j;
                    }
                }
                if (minIndex == -1) {
                    return sum;
                }
                visit[minIndex] = true;
                sum += minPath;
                for (int j = 0; j < size; j++) {
                    if (!visit[j] && distances[j] > graph[minIndex][j]) {
                        distances[j] = graph[minIndex][j];
                    }
                }
            }
            return sum;
        }

        public static void main(String[] args) {
            System.out.println("hello world!");
        }

}
